{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using BERT with `Simple Transformers`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install `simpletransformers`, which is built on top of the Transformers library by HuggingFace:  https://github.com/ThilinaRajapakse/simpletransformers/\n",
    "\n",
    "**Binary classification walkthrough**:  https://towardsdatascience.com/simple-transformers-introducing-the-easiest-bert-roberta-xlnet-and-xlm-library-58bf8c59b2a3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Install through the kernel's own interpreter so the pinned packages land in\n",
    "# the environment this notebook imports from. A bare `pip` is resolved via the\n",
    "# shell PATH, which may belong to a different Python installation.\n",
    "!{sys.executable} -m pip install -q boto3\n",
    "!{sys.executable} -m pip install -q scikit-learn==0.20.3\n",
    "!{sys.executable} -m pip install -q simpletransformers==0.22.1\n",
    "!{sys.executable} -m pip install -q tensorboardx==2.0\n",
    "!{sys.executable} -m pip install -q torch==1.4.0 torchvision==0.5.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import pandas as pd\n",
    "import sagemaker\n",
    "\n",
    "# SageMaker session, its default bucket, and the execution role of this notebook.\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "\n",
    "# Region of the default boto3 session; the SageMaker client is created for that region.\n",
    "boto_session = boto3.Session()\n",
    "region = boto_session.region_name\n",
    "sm = boto_session.client(service_name='sagemaker', region_name=region)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download the Data Locally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the object is read from the session's default bucket, so it was\n",
    "# presumably copied there by an earlier step - confirm before running.\n",
    "reviews_s3_uri = 's3://{}/amazon-reviews-pds/tsv/amazon_reviews_us_Digital_Software_v1_00.tsv.gz'.format(bucket)\n",
    "\n",
    "# IPython substitutes {reviews_s3_uri} before the command reaches the shell.\n",
    "!aws s3 cp '{reviews_s3_uri}' ./data/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "# Tab-separated, gzip-compressed file. QUOTE_NONE makes the parser treat quote\n",
    "# characters as ordinary text rather than as field quoting.\n",
    "reviews_path = './data/amazon_reviews_us_Digital_Software_v1_00.tsv.gz'\n",
    "df = pd.read_csv(reviews_path, sep='\\t', quoting=csv.QUOTE_NONE, compression='gzip')\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first five rows of the raw reviews.\n",
    "df.head(n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Enrich the Data with `is_positive_sentiment` Column"
   ]
  },
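  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Left disabled: the rest of this notebook trains on the five `star_rating`\n",
    "# classes, so the binary `is_positive_sentiment` column is not used.\n",
    "# df['is_positive_sentiment'] = (df['star_rating'] >= 4).astype(int)\n",
    "# df.head(5)"
   ]
  },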
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adapt the Dataset to Simple Transformers Convention\n",
    "\n",
    "By `simpletransformers` convention, the dataframe must have 2 columns:\n",
    "* `text`\n",
    "* `labels`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Keep only the review text and its star rating, renamed to the column names\n",
    "# simpletransformers expects. `rename` returns a new, independent DataFrame, so\n",
    "# later column assignments on `df_bert` do not write into a slice of `df`\n",
    "# (which is what raises pandas' SettingWithCopyWarning).\n",
    "df_bert = (\n",
    "    df[['review_body', 'star_rating']]\n",
    "    .rename(columns={'review_body': 'text', 'star_rating': 'labels'})\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start Labels at 0 instead of 1\n",
    "SimpleTransformers requires that our labels start with 0:  https://medium.com/swlh/simple-transformers-multi-class-text-classification-with-bert-roberta-xlnet-xlm-and-8b585000ce3a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map star ratings 1..5 to class ids 0..4.\n",
    "# The labels are derived from the untouched `df['star_rating']` (aligned on the\n",
    "# row index) rather than from `df_bert['labels']` itself, so re-running this\n",
    "# cell gives the same result instead of subtracting 1 a second time.\n",
    "# `assign` returns a new frame, so nothing is written into a slice of `df`.\n",
    "df_bert = df_bert.assign(labels=df['star_rating'] - 1)\n",
    "\n",
    "# A short preview is enough to check the shifted labels.\n",
    "df_bert.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# To Lower the Training Time, Use a Subset of the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the first 2000 rows (positional slice) to shorten training.\n",
    "df_bert = df_bert.iloc[:2000]\n",
    "df_bert.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Split the Data into `train`, `validation`, and `test`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Fixed seed so the three splits, and every metric computed from them, are the\n",
    "# same on each Restart & Run All.\n",
    "SPLIT_SEED = 42\n",
    "\n",
    "# 60% train; the remaining 40% is halved into validation and test (20% each).\n",
    "df_bert_train, df_bert_holdout = train_test_split(df_bert, test_size=0.40, random_state=SPLIT_SEED)\n",
    "df_bert_validation, df_bert_test = train_test_split(df_bert_holdout, test_size=0.50, random_state=SPLIT_SEED)\n",
    "\n",
    "print(df_bert_train.shape)\n",
    "print(df_bert_validation.shape)\n",
    "print(df_bert_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the Classification Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from simpletransformers.classification import ClassificationModel\n",
    "\n",
    "# Commented-out example of a fuller `args` dictionary, kept for reference.\n",
    "# args = {\n",
    "#    'output_dir': 'outputs/',\n",
    "#    'cache_dir': 'cache/',\n",
    "#    'fp16': False,\n",
    "#    'max_seq_length': 128,\n",
    "#    'train_batch_size': 8,\n",
    "#    'eval_batch_size': 8,\n",
    "#    'gradient_accumulation_steps': 1,\n",
    "#    'num_train_epochs': 1,\n",
    "#    'weight_decay': 0,\n",
    "#    'learning_rate': 3e-5,\n",
    "#    'adam_epsilon': 1e-8,\n",
    "#    'warmup_ratio': 0.06,\n",
    "#    'warmup_steps': 0,\n",
    "#    'max_grad_norm': 1.0,\n",
    "#    'logging_steps': 50,\n",
    "#    'evaluate_during_training': False,\n",
    "#    'save_steps': 2000,\n",
    "#    'eval_all_checkpoints': True,\n",
    "#    'use_tensorboard': True,\n",
    "#    'tensorboard_dir': 'tensorboard',\n",
    "#    'overwrite_output_dir': True,\n",
    "#    'reprocess_input_data': False,\n",
    "# }\n",
    "\n",
    "# Only these three settings are passed explicitly for this run.\n",
    "train_args = dict(\n",
    "    reprocess_input_data=True,\n",
    "    overwrite_output_dir=True,\n",
    "    num_train_epochs=3,\n",
    ")\n",
    "\n",
    "# DistilBERT classifier with five output labels (class ids 0-4), CUDA disabled.\n",
    "bert_model = ClassificationModel(\n",
    "    model_type='distilbert',  # bert, distilbert, etc, etc.\n",
    "    model_name='distilbert-base-cased',\n",
    "    args=train_args,\n",
    "    use_cuda=False,\n",
    "    num_labels=5,\n",
    ")\n",
    "\n",
    "bert_model.train_model(\n",
    "    train_df=df_bert_train,\n",
    "    eval_df=df_bert_validation,\n",
    "    show_running_loss=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate the Model Using the `test` Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Import the submodule explicitly: a plain `import sklearn` does not itself load\n",
    "# `sklearn.metrics`, so `sklearn.metrics.accuracy_score` would only resolve if\n",
    "# some other import had already pulled the submodule in.\n",
    "import sklearn.metrics\n",
    "\n",
    "# Evaluate on the held-out test split; `acc` adds accuracy as an extra metric.\n",
    "result, model_outputs, wrong_predictions = bert_model.eval_model(\n",
    "    eval_df=df_bert_test,\n",
    "    acc=sklearn.metrics.accuracy_score,\n",
    ")\n",
    "\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show the Bad Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report how many test examples were misclassified, then print the text of each.\n",
    "num_wrong = len(wrong_predictions)\n",
    "print('Number of wrong predictions: {}'.format(num_wrong))\n",
    "print('\\n')\n",
    "\n",
    "for wrong_example in wrong_predictions:\n",
    "    review_text = wrong_example.text_a\n",
    "    print(review_text)\n",
    "    print('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate the Accuracy and Precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Predict a class id for every review text in the test split.\n",
    "test_texts = df_bert_test['text'].tolist()\n",
    "preds_test, preds_raw_outputs = bert_model.predict(test_texts)\n",
    "preds_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth class ids of the test split.\n",
    "y_test = df_bert_test.loc[:, 'labels']\n",
    "y_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, precision_score, classification_report, confusion_matrix\n",
    "\n",
    "# Overall accuracy, then one precision value per class (`average=None`).\n",
    "test_accuracy = accuracy_score(y_test, preds_test)\n",
    "test_precision = precision_score(y_test, preds_test, average=None)\n",
    "\n",
    "print('Test Accuracy: ', test_accuracy)\n",
    "print('Test Precision: ', test_precision)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Perform Ad-Hoc Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ad-hoc check with a clearly favourable review.\n",
    "sample_review = 'I really enjoyed this item.  I highly recommend it.'\n",
    "predictions, raw_outputs = bert_model.predict([sample_review])\n",
    "\n",
    "print('Predictions: {}'.format(predictions))\n",
    "print('Raw outputs: {}'.format(raw_outputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ad-hoc check with a clearly unfavourable review.\n",
    "sample_review = 'This item is awful and terrible.'\n",
    "predictions, raw_outputs = bert_model.predict([sample_review])\n",
    "\n",
    "print('Predictions: {}'.format(predictions))\n",
    "print('Raw outputs: {}'.format(raw_outputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
